{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.加载数据"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.1 训练集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>分类</th>\n",
       "      <th>内容</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>娱乐</td>\n",
       "      <td>《青蛇》造型师默认新《红楼梦》额妆抄袭（图） 凡是看过电影《青蛇》的人，都不会忘记青白二蛇的...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>娱乐</td>\n",
       "      <td>６．１６日剧榜　＜最后的朋友＞　亮最后杀招成功登顶 《最后的朋友》本周的电视剧排行榜单依然只...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>娱乐</td>\n",
       "      <td>超乎想象的好看《纳尼亚传奇２：凯斯宾王子》 现时资讯如此发达，搜狐电影评审团几乎人人在没有看...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>娱乐</td>\n",
       "      <td>吴宇森：赤壁大战不会出现在上集 “希望《赤壁》能给你们不一样的感觉。”对于自己刚刚拍完的影片...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>娱乐</td>\n",
       "      <td>组图：《多情女人痴情男》陈浩民现场耍宝 陈浩民：外面的朋友大家好，现在是搜狐现场直播，欢迎《...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   分类                                                 内容\n",
       "0  娱乐  《青蛇》造型师默认新《红楼梦》额妆抄袭（图） 凡是看过电影《青蛇》的人，都不会忘记青白二蛇的...\n",
       "1  娱乐  ６．１６日剧榜　＜最后的朋友＞　亮最后杀招成功登顶 《最后的朋友》本周的电视剧排行榜单依然只...\n",
       "2  娱乐  超乎想象的好看《纳尼亚传奇２：凯斯宾王子》 现时资讯如此发达，搜狐电影评审团几乎人人在没有看...\n",
       "3  娱乐  吴宇森：赤壁大战不会出现在上集 “希望《赤壁》能给你们不一样的感觉。”对于自己刚刚拍完的影片...\n",
       "4  娱乐  组图：《多情女人痴情男》陈浩民现场耍宝 陈浩民：外面的朋友大家好，现在是搜狐现场直播，欢迎《..."
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "train_df = pd.read_csv('sohu_train.txt', sep='\\t', header=None)\n",
    "train_df.columns = ['分类', '内容']\n",
    "train_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "体育 2000\n",
      "健康 2000\n",
      "女人 2000\n",
      "娱乐 2000\n",
      "房地产 2000\n",
      "教育 2000\n",
      "文化 2000\n",
      "新闻 2000\n",
      "旅游 2000\n",
      "汽车 2000\n",
      "科技 2000\n",
      "财经 2000\n"
     ]
    }
   ],
   "source": [
    "for name, group in train_df.groupby(train_df.columns[0]):\n",
    "    print(name,len(group))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 测试集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "体育 1000\n",
      "健康 1000\n",
      "女人 1000\n",
      "娱乐 1000\n",
      "房地产 1000\n",
      "教育 1000\n",
      "文化 1000\n",
      "新闻 1000\n",
      "旅游 1000\n",
      "汽车 1000\n",
      "科技 1000\n",
      "财经 1000\n"
     ]
    }
   ],
   "source": [
    "test_df = pd.read_csv('sohu_test.txt', sep='\\t', header=None)\n",
    "for name, group in test_df.groupby(test_df.columns[0]):\n",
    "    print(name, len(group))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.分词"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 加载停顿词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "停顿词列表，即变量stopword_list中共有1452个元素\n",
      "停顿词集合，即变量stopword_set中共有1213个元素\n"
     ]
    }
   ],
   "source": [
    "with open('./stopwords.txt', encoding='utf8') as file:\n",
    "    line_list = file.readlines()\n",
    "    stopword_list = [k.strip() for k in line_list]\n",
    "    stopword_set = set(stopword_list)\n",
    "    print('停顿词列表，即变量stopword_list中共有%d个元素' %len(stopword_list))\n",
    "    print('停顿词集合，即变量stopword_set中共有%d个元素' %len(stopword_set))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 使用jieba库制作分词结果列表cutWords_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache C:\\Users\\ADMINI~1\\AppData\\Local\\Temp\\jieba.cache\n",
      "Loading model cost 0.502 seconds.\n",
      "Prefix dict has been built succesfully.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "前3000篇文章分词共花费12.34秒\n",
      "前6000篇文章分词共花费29.44秒\n",
      "前9000篇文章分词共花费35.39秒\n",
      "前12000篇文章分词共花费43.45秒\n",
      "前15000篇文章分词共花费49.67秒\n",
      "前18000篇文章分词共花费56.70秒\n",
      "前21000篇文章分词共花费69.91秒\n",
      "前24000篇文章分词共花费75.45秒\n"
     ]
    }
   ],
   "source": [
    "import jieba\n",
    "import time\n",
    "\n",
    "cutWords_list = []\n",
    "startTime = time.time()\n",
    "content_series = train_df['内容']\n",
    "for i in range(len(content_series)):\n",
    "    content = content_series.iloc[i]\n",
    "    cutWords = [k for k in jieba.cut(content, True) if k not in stopword_set]\n",
    "    if (i+1) % 3000 == 0:\n",
    "        usedTime = time.time() - startTime\n",
    "        print('前%d篇文章分词共花费%.2f秒' %(i+1, usedTime))\n",
    "    cutWords_list.append(cutWords)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 保存分词结果列表cutWords_list到文本文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "txtFilePath = 'saved_variable/cutWords_list.txt'\n",
    "with open(txtFilePath, 'w', encoding='utf8') as file:\n",
    "    for cutWords in cutWords_list:\n",
    "        file.write(' '.join(cutWords))\n",
    "        file.write('\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.4 从文本文件加载分词结果列表cutWords_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "txtFilePath = 'saved_variable/cutWords_list.txt'\n",
    "with open(txtFilePath, 'r', encoding='utf8') as file:\n",
    "    cutWords_list = [k.split(' ') for k in file.readlines()]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.word2vec模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.1 word2vec模型实例化对象"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "形成word2vec模型共花费176.25秒\n"
     ]
    }
   ],
   "source": [
    "from gensim.models import Word2Vec\n",
    "startTime = time.time()\n",
    "word2vec_model = Word2Vec(cutWords_list, size=200, iter=10, min_count=20)\n",
    "usedTime = time.time() - startTime\n",
    "print('形成word2vec模型共花费%.2f秒' %usedTime)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.2 通过word2vec对象的most_similar方法获取词义相近的次"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('摄影师', 0.5912884473800659),\n",
       " ('摄影家', 0.5197041034698486),\n",
       " ('摄影展', 0.4974055290222168),\n",
       " ('摄影机', 0.4763668477535248),\n",
       " ('摄影记者', 0.4559651017189026),\n",
       " ('摄影艺术', 0.45154887437820435),\n",
       " ('作曲', 0.44387632608413696),\n",
       " ('摄影奖', 0.4405069649219513),\n",
       " ('摄影棚', 0.43817073106765747),\n",
       " ('人体摄影', 0.4270936846733093)]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "word2vec_model.wv.most_similar('摄影')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('女士', 0.5280015468597412)]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "word2vec_model.most_similar(positive=['女人', '先生'], negative=['男人'], topn=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.3 使用pickle库保存 word2vec模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle \n",
    "\n",
    "pickleFilePath = 'saved_variable/word2vec_model.pickle'\n",
    "with open(pickleFilePath, 'wb') as file:\n",
    "    pickle.dump(word2vec_model, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.4 使用pickle库加载word2vec模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "paramiko missing, opening SSH/SCP/SFTP paths will be disabled.  `pip install paramiko` to suppress\n"
     ]
    }
   ],
   "source": [
    "import pickle \n",
    "\n",
    "pickleFilePath = 'saved_variable/word2vec_model.pickle'\n",
    "with open(pickleFilePath, 'rb') as file:\n",
    "    word2vec_model = pickle.load(file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.特征工程"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1 每篇文章的内容表示成向量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def get_contentVector(cutWords, word2vec_model):\n",
    "    vector_list = [word2vec_model.wv[k] for k in cutWords if k in word2vec_model]\n",
    "    contentVector = np.array(vector_list).mean(axis=0)\n",
    "    return contentVector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "前3000篇文章内容表示成向量共花费21.11秒\n",
      "前6000篇文章内容表示成向量共花费53.62秒\n",
      "前9000篇文章内容表示成向量共花费64.50秒\n",
      "前12000篇文章内容表示成向量共花费79.25秒\n",
      "前15000篇文章内容表示成向量共花费90.29秒\n",
      "前18000篇文章内容表示成向量共花费102.94秒\n",
      "前21000篇文章内容表示成向量共花费127.81秒\n",
      "前24000篇文章内容表示成向量共花费136.92秒\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "startTime = time.time()\n",
    "contentVector_list = []\n",
    "for i in range(len(cutWords_list)):\n",
    "    cutWords = cutWords_list[i]\n",
    "    if (i+1) % 3000 == 0:\n",
    "        usedTime = time.time() - startTime\n",
    "        print('前%d篇文章内容表示成向量共花费%.2f秒' %(i+1, usedTime))\n",
    "    contentVector_list.append(get_contentVector(cutWords, word2vec_model))\n",
    "X = np.array(contentVector_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 使用ndarray对象的dump方法保存文章向量化结果X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "txtFilePath = 'saved_variable/X.txt'\n",
    "X.dump(txtFilePath)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.3 使用numpy库的load方法加载文章向量化结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "txtFilePath = 'saved_variable/X.txt'\n",
    "X = np.load(txtFilePath)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.模型训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.1 标签编码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "labelEncoder = LabelEncoder()\n",
    "y = labelEncoder.fit_transform(train_df['分类'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 检查特征矩阵和预测目标值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(24000, 200) (24000,)\n"
     ]
    }
   ],
   "source": [
    "print(X.shape, y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.3 逻辑回归模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8033333333333333"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.2)\n",
    "logisticRegression_model = LogisticRegression()\n",
    "logisticRegression_model.fit(train_X, train_y)\n",
    "logisticRegression_model.score(test_X, test_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.4 使用pickle库保存逻辑回归模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "pickleFilePath = 'saved_variable/logisticRegression_model.pickle'\n",
    "with open(pickleFilePath, 'wb') as file:\n",
    "    pickle.dump(logisticRegression_model, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.5 使用pickle库加载逻辑回归模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "pickleFilePath = 'saved_variable/logisticRegression_model.pickle'\n",
    "with open(pickleFilePath, 'rb') as file:\n",
    "    logisticRegression_model = pickle.load(file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.模型评估"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.1 交叉验证"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.80041667 0.79833333 0.79645833 0.79041667 0.7975    ]\n",
      "0.7966249999999999\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import ShuffleSplit\n",
    "from sklearn.model_selection import cross_val_score\n",
    "\n",
    "cv_split = ShuffleSplit(n_splits=5, train_size=0.7, test_size=0.2)\n",
    "score_ndarray = cross_val_score(LogisticRegression(), X, y, cv=cv_split)\n",
    "print(score_ndarray)\n",
    "\n",
    "print(score_ndarray.mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2 混淆矩阵"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 6.2.1 获取训练集文本内容向量化后的特征矩阵 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\ProgramData\\Anaconda3\\lib\\site-packages\\ipykernel_launcher.py:4: DeprecationWarning: Call to deprecated `__contains__` (Method will be removed in 4.0.0, use self.wv.__contains__() instead).\n",
      "  after removing the cwd from sys.path.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "测试集文本内容向量化花费时间119.32秒\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import jieba \n",
    "import time\n",
    "pickleFilePath = 'saved_variable/word2vec_model.pickle'\n",
    "with open(pickleFilePath, 'rb') as file:\n",
    "    word2vec_model = pickle.load(file)\n",
    "def get_featureMatrix(content_series):\n",
    "    vector_list = []\n",
    "    for content in content_series:\n",
    "        vector = get_contentVector(jieba.cut(content, True), word2vec_model)\n",
    "        vector_list.append(vector)\n",
    "    featureMatrix = np.array(vector_list)\n",
    "    return featureMatrix\n",
    "\n",
    "test_df = pd.read_csv('sohu_test.txt', sep='\\t', header=None)\n",
    "test_df.columns = ['分类', '内容']\n",
    "startTime = time.time()\n",
    "featureMatrix = getVectorMatrix(test_df['内容'])\n",
    "usedTime = time.time() - startTime\n",
    "print('测试集文本内容向量化花费时间%.2f秒' %usedTime)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 6.2.2 绘制混淆矩阵"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>体育</th>\n",
       "      <th>健康</th>\n",
       "      <th>女人</th>\n",
       "      <th>娱乐</th>\n",
       "      <th>房地产</th>\n",
       "      <th>教育</th>\n",
       "      <th>文化</th>\n",
       "      <th>新闻</th>\n",
       "      <th>旅游</th>\n",
       "      <th>汽车</th>\n",
       "      <th>科技</th>\n",
       "      <th>财经</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>体育</th>\n",
       "      <td>968</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>8</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>8</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>健康</th>\n",
       "      <td>0</td>\n",
       "      <td>827</td>\n",
       "      <td>58</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>14</td>\n",
       "      <td>5</td>\n",
       "      <td>45</td>\n",
       "      <td>6</td>\n",
       "      <td>1</td>\n",
       "      <td>12</td>\n",
       "      <td>27</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>女人</th>\n",
       "      <td>6</td>\n",
       "      <td>39</td>\n",
       "      <td>809</td>\n",
       "      <td>40</td>\n",
       "      <td>4</td>\n",
       "      <td>8</td>\n",
       "      <td>51</td>\n",
       "      <td>13</td>\n",
       "      <td>13</td>\n",
       "      <td>6</td>\n",
       "      <td>7</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>娱乐</th>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>52</td>\n",
       "      <td>803</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>109</td>\n",
       "      <td>13</td>\n",
       "      <td>7</td>\n",
       "      <td>1</td>\n",
       "      <td>8</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>房地产</th>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>8</td>\n",
       "      <td>2</td>\n",
       "      <td>897</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>17</td>\n",
       "      <td>16</td>\n",
       "      <td>3</td>\n",
       "      <td>7</td>\n",
       "      <td>40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>教育</th>\n",
       "      <td>1</td>\n",
       "      <td>14</td>\n",
       "      <td>16</td>\n",
       "      <td>8</td>\n",
       "      <td>2</td>\n",
       "      <td>888</td>\n",
       "      <td>10</td>\n",
       "      <td>39</td>\n",
       "      <td>7</td>\n",
       "      <td>1</td>\n",
       "      <td>10</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>文化</th>\n",
       "      <td>4</td>\n",
       "      <td>6</td>\n",
       "      <td>62</td>\n",
       "      <td>161</td>\n",
       "      <td>8</td>\n",
       "      <td>16</td>\n",
       "      <td>619</td>\n",
       "      <td>48</td>\n",
       "      <td>37</td>\n",
       "      <td>9</td>\n",
       "      <td>30</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>新闻</th>\n",
       "      <td>13</td>\n",
       "      <td>32</td>\n",
       "      <td>22</td>\n",
       "      <td>15</td>\n",
       "      <td>25</td>\n",
       "      <td>56</td>\n",
       "      <td>55</td>\n",
       "      <td>591</td>\n",
       "      <td>29</td>\n",
       "      <td>10</td>\n",
       "      <td>49</td>\n",
       "      <td>103</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>旅游</th>\n",
       "      <td>4</td>\n",
       "      <td>8</td>\n",
       "      <td>26</td>\n",
       "      <td>6</td>\n",
       "      <td>20</td>\n",
       "      <td>1</td>\n",
       "      <td>36</td>\n",
       "      <td>29</td>\n",
       "      <td>844</td>\n",
       "      <td>8</td>\n",
       "      <td>11</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>汽车</th>\n",
       "      <td>5</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>6</td>\n",
       "      <td>1</td>\n",
       "      <td>6</td>\n",
       "      <td>13</td>\n",
       "      <td>929</td>\n",
       "      <td>9</td>\n",
       "      <td>19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>科技</th>\n",
       "      <td>1</td>\n",
       "      <td>14</td>\n",
       "      <td>13</td>\n",
       "      <td>10</td>\n",
       "      <td>2</td>\n",
       "      <td>12</td>\n",
       "      <td>24</td>\n",
       "      <td>33</td>\n",
       "      <td>15</td>\n",
       "      <td>5</td>\n",
       "      <td>827</td>\n",
       "      <td>44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>财经</th>\n",
       "      <td>2</td>\n",
       "      <td>18</td>\n",
       "      <td>18</td>\n",
       "      <td>3</td>\n",
       "      <td>55</td>\n",
       "      <td>9</td>\n",
       "      <td>2</td>\n",
       "      <td>85</td>\n",
       "      <td>17</td>\n",
       "      <td>27</td>\n",
       "      <td>67</td>\n",
       "      <td>697</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      体育   健康   女人   娱乐  房地产   教育   文化   新闻   旅游   汽车   科技   财经\n",
       "体育   968    1    5    8    0    2    3    2    8    1    0    2\n",
       "健康     0  827   58    2    3   14    5   45    6    1   12   27\n",
       "女人     6   39  809   40    4    8   51   13   13    6    7    4\n",
       "娱乐     5    0   52  803    0    2  109   13    7    1    8    0\n",
       "房地产    3    0    8    2  897    3    4   17   16    3    7   40\n",
       "教育     1   14   16    8    2  888   10   39    7    1   10    4\n",
       "文化     4    6   62  161    8   16  619   48   37    9   30    0\n",
       "新闻    13   32   22   15   25   56   55  591   29   10   49  103\n",
       "旅游     4    8   26    6   20    1   36   29  844    8   11    7\n",
       "汽车     5    4    4    2    2    6    1    6   13  929    9   19\n",
       "科技     1   14   13   10    2   12   24   33   15    5  827   44\n",
       "财经     2   18   18    3   55    9    2   85   17   27   67  697"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "pickleFilePath = 'saved_variable/logisticRegression_model.pickle'\n",
    "with open(pickleFilePath, 'rb') as file:\n",
    "    logisticRegression_model = pickle.load(file)\n",
    "test_label_list = labelEncoder.transform(test_df['分类'])\n",
    "predict_label_list = logisticRegression_model.predict(featureMatrix)\n",
    "pd.DataFrame(confusion_matrix(test_label_list, predict_label_list), \n",
    "             columns=labelEncoder.classes_,\n",
    "             index=labelEncoder.classes_ )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.3 报告表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Label</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1</th>\n",
       "      <th>Support</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>体育</td>\n",
       "      <td>0.956522</td>\n",
       "      <td>0.96800</td>\n",
       "      <td>0.962227</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>健康</td>\n",
       "      <td>0.858775</td>\n",
       "      <td>0.82700</td>\n",
       "      <td>0.842588</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>女人</td>\n",
       "      <td>0.740165</td>\n",
       "      <td>0.80900</td>\n",
       "      <td>0.773053</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>娱乐</td>\n",
       "      <td>0.757547</td>\n",
       "      <td>0.80300</td>\n",
       "      <td>0.779612</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>房地产</td>\n",
       "      <td>0.881139</td>\n",
       "      <td>0.89700</td>\n",
       "      <td>0.888999</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>教育</td>\n",
       "      <td>0.873156</td>\n",
       "      <td>0.88800</td>\n",
       "      <td>0.880516</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>文化</td>\n",
       "      <td>0.673558</td>\n",
       "      <td>0.61900</td>\n",
       "      <td>0.645128</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>新闻</td>\n",
       "      <td>0.641694</td>\n",
       "      <td>0.59100</td>\n",
       "      <td>0.615305</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>旅游</td>\n",
       "      <td>0.833992</td>\n",
       "      <td>0.84400</td>\n",
       "      <td>0.838966</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>汽车</td>\n",
       "      <td>0.928072</td>\n",
       "      <td>0.92900</td>\n",
       "      <td>0.928536</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>科技</td>\n",
       "      <td>0.797493</td>\n",
       "      <td>0.82700</td>\n",
       "      <td>0.811978</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>财经</td>\n",
       "      <td>0.736008</td>\n",
       "      <td>0.69700</td>\n",
       "      <td>0.715973</td>\n",
       "      <td>1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>999</th>\n",
       "      <td>总体</td>\n",
       "      <td>0.806510</td>\n",
       "      <td>0.80825</td>\n",
       "      <td>0.806907</td>\n",
       "      <td>12000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    Label  Precision   Recall        F1  Support\n",
       "0      体育   0.956522  0.96800  0.962227     1000\n",
       "1      健康   0.858775  0.82700  0.842588     1000\n",
       "2      女人   0.740165  0.80900  0.773053     1000\n",
       "3      娱乐   0.757547  0.80300  0.779612     1000\n",
       "4     房地产   0.881139  0.89700  0.888999     1000\n",
       "5      教育   0.873156  0.88800  0.880516     1000\n",
       "6      文化   0.673558  0.61900  0.645128     1000\n",
       "7      新闻   0.641694  0.59100  0.615305     1000\n",
       "8      旅游   0.833992  0.84400  0.838966     1000\n",
       "9      汽车   0.928072  0.92900  0.928536     1000\n",
       "10     科技   0.797493  0.82700  0.811978     1000\n",
       "11     财经   0.736008  0.69700  0.715973     1000\n",
       "999    总体   0.806510  0.80825  0.806907    12000"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import precision_recall_fscore_support\n",
    "\n",
    "def eval_model(test_label_list, predict_label_list, className_list):\n",
    "    # 计算每个分类的Precision, Recall, f1, support\n",
    "    p, r, f1, s = precision_recall_fscore_support(test_label_list, predict_label_list)\n",
    "    # 计算总体的平均Precision, Recall, f1, support\n",
    "    total_p = np.average(p, weights=s)\n",
    "    total_r = np.average(r, weights=s)\n",
    "    total_f1 = np.average(f1, weights=s)\n",
    "    total_s = np.sum(s)\n",
    "    res1 = pd.DataFrame({\n",
    "        u'Label': className_list,\n",
    "        u'Precision': p,\n",
    "        u'Recall': r,\n",
    "        u'F1': f1,\n",
    "        u'Support': s\n",
    "    })\n",
    "    res2 = pd.DataFrame({\n",
    "        u'Label': ['总体'],\n",
    "        u'Precision': [total_p],\n",
    "        u'Recall': [total_r],\n",
    "        u'F1': [total_f1],\n",
    "        u'Support': [total_s]\n",
    "    })\n",
    "    res2.index = [999]\n",
    "    res = pd.concat([res1, res2])\n",
    "    return res[['Label', 'Precision', 'Recall', 'F1', 'Support']]\n",
    "\n",
    "eval_model(test_label_list, predict_label_list, labelEncoder.classes_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
